% UPDATE_DIST_SS  One update of the (wage, ability) distribution together with
% the endogenous wage grid.
%
% The wage grid is rebuilt every call so that the discrete approximation stays
% accurate when wage growth occurs: it is centred (multiplicatively) on the
% mean next-period wage and spans +/- high_low standard deviations in logs.
%
% Inputs:
%   Prob   - current distribution; indexed Prob(i,j) with i = 1..N_w over the
%            wage grid w and j = 1..N_a over the second state dimension.
%   Pr     - transition over the second state; row Pr(j,:) is combined with
%            the wage distribution for states starting in j (presumably each
%            row sums to one -- confirm against caller).
%   Ss, rs - forwarded unchanged to choiceProb.
%
% Output:
%   Probm  - next-period distribution on the NEW wage grid, same size as Prob,
%            normalized so that sum(Probm(:)) == 1.
%
% Side effects (globals): overwrites w (with the new grid), wspread, w_max
% and w_min.
function Probm = update_dist_ss(Prob, Pr, Ss, rs)

    global bet0 bet1 tau a b w e pi_theta N_a N_theta N_w N_e w_max w_min alpha xi xi1 ub lb pi_e...
        theta Int_pi_w sigmae elasts log_e_mu log_e_sigma phi sigma gamma rts w_mult wspread frac0...
        unemp

    n_types = numel(pi_theta);

    % Mean of log(e) such that E[e] = 1 for the lognormal wage shock
    % (log e ~ N(ln_e_hat, log_e_sigma^2)).
    ln_e_hat = -1/2*log_e_sigma^2;

    % Choice probabilities P_S (indexed (i, j, choice, type)) and per-choice
    % quantities F_es and wprimeS from choiceProb.
    [P_S, F_es, wprimeS, ~] = choiceProb(Ss, rs);

    % ---------------------------------------------------------------------
    % Endogenous wage grid
    % ---------------------------------------------------------------------
    % Distribution over (wage, j, choice), integrated over types.
    type_weights = reshape(pi_theta, 1, 1, 1, []);
    myDist1 = sum(Prob .* P_S .* type_weights, 4);
    w_bar_p = sum(myDist1 .* wprimeS, 1:3);

    % Log-wage dispersion: weighted variance of log(w) under the current wage
    % marginal plus the variance of the log shock.
    log_w_sigma = sqrt(var(log(w), sum(myDist1, 2:3)) + var(log(e), pi_e));

    high_low = 4;  % grid spans +/- high_low standard deviations in logs
    wspread = exp(linspace(-high_low, high_low, N_w)' * log_w_sigma);
    w2 = w_bar_p * wspread;
    w_max = max(w2);
    w_min = min(w2);
    log_w2 = log(w2);  % loop-invariant, hoisted

    % ---------------------------------------------------------------------
    % Update distribution
    % ---------------------------------------------------------------------
    % Preallocate matrix for new distribution
    Probm = zeros(size(Prob));
    for j = 1:N_a
        for i = 1:N_w
            % Pre-shock next-period wage for each choice (row vector). The
            % shock is i.i.d. lognormal, so log w' = log(t1) + log(e).
            t1 = unemp + (b + squeeze(F_es(i,j,:))') .* f_w(w(i));

            % Allocate shock mass to the closest new grid point: bin edges
            % are -Inf, midpoints between adjacent grid points, and +Inf.
            % (Previously the lowest edge was the lowest grid point itself,
            % which silently dropped all mass below it.)
            z = log_w2 - log(t1);
            n_choice = size(z, 2);
            edges = [-Inf(1, n_choice); (z(2:end,:) + z(1:end-1,:))/2; Inf(1, n_choice)];
            mean_prob = diff(normcdf((edges - ln_e_hat) / log_e_sigma));

            % Choice probabilities for state (i,j), mixed over types.
            % mean_prob does not depend on the type, so mixing first is
            % equivalent to summing the per-type contributions.
            p_choice = reshape(P_S(i,j,:,:), [], n_types) * pi_theta(:);

            % Outer product: new-wage distribution (column) times the
            % transition of the second state (row).
            Probm = Probm + Prob(i,j) * (mean_prob * p_choice) * Pr(j,:);
        end
    end

    % Guard against any residual mass imbalance (e.g. rows of Pr or of the
    % choice probabilities not summing exactly to one).
    Probm = Probm / sum(Probm(:));

    % Publish the new grid; the loop above needed the old one via w(i).
    w = w2;

end